{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "77b3288a-d7ee-4054-8c2e-c1eb09b3fa5c",
   "metadata": {},
   "source": [
    "<img src=\"http://developer.download.nvidia.com/compute/machine-learning/frameworks/nvidia_logo.png\" style=\"width: 90px; float: right;\">\n",
    "\n",
    "# HugeCTR to ONNX Converter"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0690ef9f-4601-42e0-bef3-0344865faa54",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Overview\n",
    "\n",
    "To improve compatibility and interoperability with other deep-learning frameworks, we provide a Python module that converts HugeCTR models to ONNX.\n",
    "ONNX is an open-source format for AI models.\n",
    "The converter takes the model graph in JSON, the dense model, and the sparse models as inputs and saves the converted ONNX model to the specified path.\n",
    "All the required input files can be produced with the HugeCTR training APIs, so the whole workflow can be carried out in Python.\n",
    "\n",
    "This notebook demonstrates how to access and use the HugeCTR to ONNX converter.\n",
    "The HugeCTR training APIs are also covered here for completeness, so make sure that you are familiar with them.\n",
    "For more details about using the converter, refer to the [onnx_converter](https://github.com/NVIDIA-Merlin/HugeCTR/tree/master/onnx_converter) directory of the repository."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da25e22f-b515-471b-8bcd-2dfb9c35639a",
   "metadata": {},
   "source": [
    "## Access the HugeCTR to ONNX Converter\n",
    "\n",
    "Make sure that you start the notebook inside a running NGC Docker container, version 22.05 or later, such as `nvcr.io/nvidia/merlin/merlin-training:22.05`.\n",
    "The ONNX converter module is installed in `/usr/local/lib/python3.8/dist-packages`.\n",
    "For the HugeCTR Python interface, a dynamic link to the `hugectr.so` library is installed in `/usr/local/hugectr/lib/`.\n",
    "You can access both the ONNX converter and the HugeCTR Python interface from anywhere within the container."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a09b68c3-2605-4c43-9cbe-1fa938469ebb",
   "metadata": {},
   "source": [
    "Run the following cell to confirm that the HugeCTR Python interface can be accessed correctly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ead9bdba-2f27-416f-8aa3-b7c25570817e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import hugectr"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a911deed-274f-4db6-9268-e4946a5ea6c9",
   "metadata": {},
   "source": [
    "Run the following cell to confirm that the HugeCTR to ONNX converter can be accessed correctly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "36a914a7-2246-484f-93c7-725a5a1da59d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import hugectr2onnx"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a41d8931-7430-4ea8-8cdc-ec01d7670c92",
   "metadata": {},
   "source": [
    "## Wide and Deep Model\n",
    "\n",
    "### Download and Preprocess Data\n",
    "\n",
    "1. Download the Criteo dataset using the following command:\n",
    "\n",
    "   ```shell\n",
    "   $ cd ${project_root}/tools\n",
    "   $ wget http://azuremlsampleexperiments.blob.core.windows.net/criteo/day_1.gz\n",
    "   ```\n",
    "   \n",
    "   During preprocessing, we reduce the amount of data to speed up the process, fill in missing values, and remove feature values that occur very rarely. Here we use the pandas preprocessing method to prepare the dataset for HugeCTR training.\n",
    "\n",
    "2. Preprocess the data with pandas using the following command:\n",
    "\n",
    "   ```shell\n",
    "   $ bash preprocess.sh 1 wdl_data pandas 1 1 100\n",
    "   ```\n",
    "   \n",
    "   The first argument is the dataset postfix, which is `1` here because `day_1` is used. The second argument, `wdl_data`, is the directory where the preprocessed data is stored. The fourth argument (the one after `pandas`), `1`, enables normalization of the dense features. The fifth argument, `1`, enables feature crossing. The last argument, `100`, is the number of data files in each file list.\n",
    "   \n",
    "3. Create a soft link to the dataset folder using the following command:\n",
    "\n",
    "   ```shell\n",
    "   $ ln -s ${project_root}/tools/wdl_data ${project_root}/notebooks/wdl_data\n",
    "   ```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "953efdd4-711b-4301-9442-8fbd3c70f1e8",
   "metadata": {},
   "source": [
    "### Train the HugeCTR Model\n",
    "\n",
    "We can train from scratch, dump the model graph to a JSON file, and save the model weights and optimizer states by performing the following steps with the Python APIs:\n",
    "\n",
    "1. Create the solver, reader, and optimizer, then initialize the model.\n",
    "2. Construct the model graph by adding the input, sparse embedding, and dense layers in order.\n",
    "3. Dump the model graph to the JSON file.\n",
    "4. Compile the model and print an overview of the model graph.\n",
    "5. Fit the model, which saves the model weights and optimizer states at the snapshot interval.\n",
    "\n",
    "Please note that the training mode is determined by `repeat_dataset` within `hugectr.CreateSolver`.\n",
    "If it is `True`, the non-epoch mode training is adopted and the maximum iterations should be specified by `max_iter` within `hugectr.Model.fit`.\n",
    "If it is `False`, the epoch-mode training is adopted and the number of epochs should be specified by `num_epochs` within `hugectr.Model.fit`.\n",
    "\n",
    "The optimizer that is used to initialize the model applies to the weights of dense layers, while the optimizer for each sparse embedding layer can be specified independently within `hugectr.SparseEmbedding`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b449f9f3-579c-43ce-b14e-d290e46361e8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting wdl_train.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile wdl_train.py\n",
    "import hugectr\n",
    "from mpi4py import MPI\n",
    "solver = hugectr.CreateSolver(max_eval_batches = 300,\n",
    "                              batchsize_eval = 16384,\n",
    "                              batchsize = 16384,\n",
    "                              lr = 0.001,\n",
    "                              vvgpu = [[0]],\n",
    "                              repeat_dataset = True)\n",
    "reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Norm,\n",
    "                                  source = [\"./wdl_data/file_list.txt\"],\n",
    "                                  eval_source = \"./wdl_data/file_list_test.txt\",\n",
    "                                  check_type = hugectr.Check_t.Sum)\n",
    "optimizer = hugectr.CreateOptimizer(optimizer_type = hugectr.Optimizer_t.Adam,\n",
    "                                    update_type = hugectr.Update_t.Global,\n",
    "                                    beta1 = 0.9,\n",
    "                                    beta2 = 0.999,\n",
    "                                    epsilon = 0.0000001)\n",
    "model = hugectr.Model(solver, reader, optimizer)\n",
    "model.add(hugectr.Input(label_dim = 1, label_name = \"label\",\n",
    "                        dense_dim = 13, dense_name = \"dense\",\n",
    "                        data_reader_sparse_param_array = \n",
    "                        [hugectr.DataReaderSparseParam(\"wide_data\", 2, True, 1),\n",
    "                        hugectr.DataReaderSparseParam(\"deep_data\", 1, True, 26)]))\n",
    "model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash, \n",
    "                            workspace_size_per_gpu_in_mb = 75,\n",
    "                            embedding_vec_size = 1,\n",
    "                            combiner = \"sum\",\n",
    "                            sparse_embedding_name = \"sparse_embedding2\",\n",
    "                            bottom_name = \"wide_data\",\n",
    "                            optimizer = optimizer))\n",
    "model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash, \n",
    "                            workspace_size_per_gpu_in_mb = 1074,\n",
    "                            embedding_vec_size = 16,\n",
    "                            combiner = \"sum\",\n",
    "                            sparse_embedding_name = \"sparse_embedding1\",\n",
    "                            bottom_name = \"deep_data\",\n",
    "                            optimizer = optimizer))\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,\n",
    "                            bottom_names = [\"sparse_embedding1\"],\n",
    "                            top_names = [\"reshape1\"],\n",
    "                            leading_dim=416))\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,\n",
    "                            bottom_names = [\"sparse_embedding2\"],\n",
    "                            top_names = [\"reshape2\"],\n",
    "                            leading_dim=1))\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Concat,\n",
    "                            bottom_names = [\"reshape1\", \"dense\"],\n",
    "                            top_names = [\"concat1\"]))\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,\n",
    "                            bottom_names = [\"concat1\"],\n",
    "                            top_names = [\"fc1\"],\n",
    "                            num_output=1024))\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,\n",
    "                            bottom_names = [\"fc1\"],\n",
    "                            top_names = [\"relu1\"]))\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Dropout,\n",
    "                            bottom_names = [\"relu1\"],\n",
    "                            top_names = [\"dropout1\"],\n",
    "                            dropout_rate=0.5))\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,\n",
    "                            bottom_names = [\"dropout1\"],\n",
    "                            top_names = [\"fc2\"],\n",
    "                            num_output=1024))\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReLU,\n",
    "                            bottom_names = [\"fc2\"],\n",
    "                            top_names = [\"relu2\"]))\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Dropout,\n",
    "                            bottom_names = [\"relu2\"],\n",
    "                            top_names = [\"dropout2\"],\n",
    "                            dropout_rate=0.5))\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.InnerProduct,\n",
    "                            bottom_names = [\"dropout2\"],\n",
    "                            top_names = [\"fc3\"],\n",
    "                            num_output=1))\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Add,\n",
    "                            bottom_names = [\"fc3\", \"reshape2\"],\n",
    "                            top_names = [\"add1\"]))\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.BinaryCrossEntropyLoss,\n",
    "                            bottom_names = [\"add1\", \"label\"],\n",
    "                            top_names = [\"loss\"]))\n",
    "model.graph_to_json(\"wdl.json\")\n",
    "model.compile()\n",
    "model.summary()\n",
    "model.fit(max_iter = 2300, display = 200, eval_interval = 1000, snapshot = 2000, snapshot_prefix = \"wdl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3b9830e4-8851-43a6-b327-ca53e10026e6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================================================Model Init=====================================================\n",
      "[17d09h39m52s][HUGECTR][INFO]: Global seed is 2566812942\n",
      "[17d09h39m53s][HUGECTR][INFO]: Device to NUMA mapping:\n",
      "  GPU 0 ->  node 0\n",
      "\n",
      "[17d09h39m55s][HUGECTR][INFO]: Peer-to-peer access cannot be fully enabled.\n",
      "[17d09h39m55s][HUGECTR][INFO]: Start all2all warmup\n",
      "[17d09h39m55s][HUGECTR][INFO]: End all2all warmup\n",
      "[17d09h39m55s][HUGECTR][INFO]: Using All-reduce algorithm OneShot\n",
      "Device 0: Tesla V100-SXM2-16GB\n",
      "[17d09h39m55s][HUGECTR][INFO]: num of DataReader workers: 12\n",
      "[17d09h39m55s][HUGECTR][INFO]: max_vocabulary_size_per_gpu_=6553600\n",
      "[17d09h39m55s][HUGECTR][INFO]: max_vocabulary_size_per_gpu_=5865472\n",
      "[17d09h39m55s][HUGECTR][INFO]: Save the model graph to wdl.json, successful\n",
      "===================================================Model Compile===================================================\n",
      "[17d09h40m31s][HUGECTR][INFO]: gpu0 start to init embedding\n",
      "[17d09h40m31s][HUGECTR][INFO]: gpu0 init embedding done\n",
      "[17d09h40m31s][HUGECTR][INFO]: gpu0 start to init embedding\n",
      "[17d09h40m31s][HUGECTR][INFO]: gpu0 init embedding done\n",
      "[17d09h40m31s][HUGECTR][INFO]: Starting AUC NCCL warm-up\n",
      "[17d09h40m31s][HUGECTR][INFO]: Warm-up done\n",
      "===================================================Model Summary===================================================\n",
      "Label                                   Dense                         Sparse                        \n",
      "label                                   dense                          wide_data,deep_data           \n",
      "(None, 1)                               (None, 13)                              \n",
      "------------------------------------------------------------------------------------------------------------------\n",
      "Layer Type                              Input Name                    Output Name                   Output Shape                  \n",
      "------------------------------------------------------------------------------------------------------------------\n",
      "DistributedSlotSparseEmbeddingHash      wide_data                     sparse_embedding2             (None, 1, 1)                  \n",
      "DistributedSlotSparseEmbeddingHash      deep_data                     sparse_embedding1             (None, 26, 16)                \n",
      "Reshape                                 sparse_embedding1             reshape1                      (None, 416)                   \n",
      "Reshape                                 sparse_embedding2             reshape2                      (None, 1)                     \n",
      "Concat                                  reshape1,dense                concat1                       (None, 429)                   \n",
      "InnerProduct                            concat1                       fc1                           (None, 1024)                  \n",
      "ReLU                                    fc1                           relu1                         (None, 1024)                  \n",
      "Dropout                                 relu1                         dropout1                      (None, 1024)                  \n",
      "InnerProduct                            dropout1                      fc2                           (None, 1024)                  \n",
      "ReLU                                    fc2                           relu2                         (None, 1024)                  \n",
      "Dropout                                 relu2                         dropout2                      (None, 1024)                  \n",
      "InnerProduct                            dropout2                      fc3                           (None, 1)                     \n",
      "Add                                     fc3,reshape2                  add1                          (None, 1)                     \n",
      "BinaryCrossEntropyLoss                  add1,label                    loss                                                        \n",
      "------------------------------------------------------------------------------------------------------------------\n",
      "=====================================================Model Fit=====================================================\n",
      "[17d90h40m31s][HUGECTR][INFO]: Use non-epoch mode with number of iterations: 2300\n",
      "[17d90h40m31s][HUGECTR][INFO]: Training batchsize: 16384, evaluation batchsize: 16384\n",
      "[17d90h40m31s][HUGECTR][INFO]: Evaluation interval: 1000, snapshot interval: 2000\n",
      "[17d90h40m31s][HUGECTR][INFO]: Sparse embedding trainable: 1, dense network trainable: 1\n",
      "[17d90h40m31s][HUGECTR][INFO]: Use mixed precision: 0, scaler: 1.000000, use cuda graph: 1\n",
      "[17d90h40m31s][HUGECTR][INFO]: lr: 0.001000, warmup_steps: 1, decay_start: 0, decay_steps: 1, decay_power: 2.000000, end_lr: 0.000000\n",
      "[17d90h40m31s][HUGECTR][INFO]: Training source file: ./wdl_data/file_list.txt\n",
      "[17d90h40m31s][HUGECTR][INFO]: Evaluation source file: ./wdl_data/file_list_test.txt\n",
      "[17d90h40m35s][HUGECTR][INFO]: Iter: 200 Time(200 iters): 4.140480s Loss: 0.130462 lr:0.001000\n",
      "[17d90h40m39s][HUGECTR][INFO]: Iter: 400 Time(200 iters): 4.015980s Loss: 0.127865 lr:0.001000\n",
      "[17d90h40m43s][HUGECTR][INFO]: Iter: 600 Time(200 iters): 4.015593s Loss: 0.125415 lr:0.001000\n",
      "[17d90h40m47s][HUGECTR][INFO]: Iter: 800 Time(200 iters): 4.014245s Loss: 0.132623 lr:0.001000\n",
      "[17d90h40m51s][HUGECTR][INFO]: Iter: 1000 Time(200 iters): 4.016144s Loss: 0.128454 lr:0.001000\n",
      "[17d90h40m53s][HUGECTR][INFO]: Evaluation, AUC: 0.771219\n",
      "[17d90h40m53s][HUGECTR][INFO]: Eval Time for 300 iters: 1.542828s\n",
      "[17d90h40m57s][HUGECTR][INFO]: Iter: 1200 Time(200 iters): 5.570898s Loss: 0.130795 lr:0.001000\n",
      "[17d90h41m10s][HUGECTR][INFO]: Iter: 1400 Time(200 iters): 4.013813s Loss: 0.135299 lr:0.001000\n",
      "[17d90h41m50s][HUGECTR][INFO]: Iter: 1600 Time(200 iters): 4.016540s Loss: 0.130995 lr:0.001000\n",
      "[17d90h41m90s][HUGECTR][INFO]: Iter: 1800 Time(200 iters): 4.017910s Loss: 0.132997 lr:0.001000\n",
      "[17d90h41m13s][HUGECTR][INFO]: Iter: 2000 Time(200 iters): 4.015891s Loss: 0.119502 lr:0.001000\n",
      "[17d90h41m14s][HUGECTR][INFO]: Evaluation, AUC: 0.778875\n",
      "[17d90h41m14s][HUGECTR][INFO]: Eval Time for 300 iters: 1.545023s\n",
      "[17d90h41m14s][HUGECTR][INFO]: Rank0: Write hash table to file\n",
      "[17d90h41m16s][HUGECTR][INFO]: Rank0: Write hash table to file\n",
      "[17d90h41m18s][HUGECTR][INFO]: Dumping sparse weights to files, successful\n",
      "[17d90h41m18s][HUGECTR][INFO]: Rank0: Write optimzer state to file\n",
      "[17d90h41m18s][HUGECTR][INFO]: Done\n",
      "[17d90h41m18s][HUGECTR][INFO]: Rank0: Write optimzer state to file\n",
      "[17d90h41m18s][HUGECTR][INFO]: Done\n",
      "[17d90h41m20s][HUGECTR][INFO]: Rank0: Write optimzer state to file\n",
      "[17d90h41m20s][HUGECTR][INFO]: Done\n",
      "[17d90h41m21s][HUGECTR][INFO]: Rank0: Write optimzer state to file\n",
      "[17d90h41m21s][HUGECTR][INFO]: Done\n",
      "[17d90h41m32s][HUGECTR][INFO]: Dumping sparse optimzer states to files, successful\n",
      "[17d90h41m32s][HUGECTR][INFO]: Dumping dense weights to file, successful\n",
      "[17d90h41m32s][HUGECTR][INFO]: Dumping dense optimizer states to file, successful\n",
      "[17d90h41m32s][HUGECTR][INFO]: Dumping untrainable weights to file, successful\n",
      "[17d90h41m37s][HUGECTR][INFO]: Iter: 2200 Time(200 iters): 24.012700s Loss: 0.128822 lr:0.001000\n",
      "Finish 2300 iterations with batchsize: 16384 in 67.90s\n"
     ]
    }
   ],
   "source": [
    "!python3 wdl_train.py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ec3f7028-33eb-46c5-9515-f13c32c118b5",
   "metadata": {},
   "source": [
    "### Convert to ONNX\n",
    "\n",
    "We can convert the trained HugeCTR model to ONNX with a call to `hugectr2onnx.converter.convert`. The `convert_embedding` flag specifies whether to convert the sparse embeddings. If it is set to `False`, the sparse models do not need to be provided. In this notebook, both the dense and sparse parts of the HugeCTR model are converted to ONNX so that we can check the correctness of the conversion more easily by comparing the inference results of HugeCTR and ONNX Runtime."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fab376f6-23d2-454f-874e-6fc14c822cf6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The model is checked!\n",
      "The model is saved at wdl.onnx\n"
     ]
    }
   ],
   "source": [
    "import hugectr2onnx\n",
    "hugectr2onnx.converter.convert(onnx_model_path = \"wdl.onnx\",\n",
    "                            graph_config = \"wdl.json\",\n",
    "                            dense_model = \"wdl_dense_2000.model\",\n",
    "                            convert_embedding = True,\n",
    "                            sparse_models = [\"wdl0_sparse_2000.model\", \"wdl1_sparse_2000.model\"])"
   ]
  },
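  {
   "cell_type": "markdown",
   "id": "dense-only-conversion-note",
   "metadata": {},
   "source": [
    "As noted above, the sparse models can be omitted when `convert_embedding` is `False`. The following cell is a minimal sketch of such a dense-only conversion, assuming that `wdl.json` and `wdl_dense_2000.model` from the training step exist in the working directory. The output file name `wdl_dense_only.onnx` is a hypothetical name chosen for illustration. The inputs of the resulting graph may differ from those of the full model, so the inference cells in this notebook keep using `wdl.onnx`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dense-only-conversion-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import hugectr2onnx\n",
    "# Dense-only conversion: sparse_models keeps its default value because convert_embedding is False.\n",
    "# \"wdl_dense_only.onnx\" is a hypothetical output name used for illustration only.\n",
    "hugectr2onnx.converter.convert(onnx_model_path = \"wdl_dense_only.onnx\",\n",
    "                            graph_config = \"wdl.json\",\n",
    "                            dense_model = \"wdl_dense_2000.model\",\n",
    "                            convert_embedding = False)"
   ]
  },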
  {
   "cell_type": "markdown",
   "id": "24925559-d0dd-4836-9e9c-6fabec29c698",
   "metadata": {},
   "source": [
    "### Inference with ONNX Runtime and HugeCTR\n",
    "\n",
    "To make inferences with ONNX Runtime, we need to read samples from the data and feed them to the ONNX inference session. Specifically, we need to extract the dense features, wide sparse features, and deep sparse features from the preprocessed Wide&Deep dataset. To guarantee a fair comparison with HugeCTR inference, we use the first data file listed in `./wdl_data/file_list_test.txt`, i.e., `./wdl_data/val/sparse_embedding0.data`, and make inferences for the same number of samples, which should be less than the total number of samples in that file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9116e718-1d56-4f87-924e-01f1b986baa8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import struct\n",
    "import numpy as np\n",
    "def read_samples_for_wdl(data_file, num_samples, key_type=\"I32\", slot_num=27):\n",
    "    key_type_map = {\"I32\": [\"I\", 4], \"I64\": [\"q\", 8]}\n",
    "    with open(data_file, 'rb') as file:\n",
    "        # skip data_header\n",
    "        file.seek(4 + 64 + 1, 0)\n",
    "        batch_label = []\n",
    "        batch_dense = []\n",
    "        batch_wide_data = []\n",
    "        batch_deep_data = []\n",
    "        for _ in range(num_samples):\n",
    "            # one sample\n",
    "            length_buffer = file.read(4) # int\n",
    "            length = struct.unpack('i', length_buffer)\n",
    "            label_buffer = file.read(4) # int\n",
    "            label = struct.unpack('i', label_buffer)[0]\n",
    "            dense_buffer = file.read(4 * 13) # dense_dim * float\n",
    "            dense = struct.unpack(\"13f\", dense_buffer)\n",
    "            keys = []\n",
    "            for _ in range(slot_num):\n",
    "                nnz_buffer = file.read(4) # int\n",
    "                nnz = struct.unpack(\"i\", nnz_buffer)[0]\n",
    "                key_buffer = file.read(key_type_map[key_type][1] * nnz) # nnz * sizeof(key_type)\n",
    "                key = struct.unpack(str(nnz) + key_type_map[key_type][0], key_buffer)\n",
    "                keys += list(key)\n",
    "            check_bit_buffer = file.read(1) # char\n",
    "            check_bit = struct.unpack(\"c\", check_bit_buffer)[0]\n",
    "            batch_label.append(label)\n",
    "            batch_dense.append(dense)\n",
    "            batch_wide_data.append(keys[0:2])\n",
    "            batch_deep_data.append(keys[2:28])\n",
    "    batch_label = np.reshape(np.array(batch_label, dtype=np.float32), newshape=(num_samples, 1))\n",
    "    batch_dense = np.reshape(np.array(batch_dense, dtype=np.float32), newshape=(num_samples, 13))\n",
    "    batch_wide_data = np.reshape(np.array(batch_wide_data, dtype=np.int64), newshape=(num_samples, 1, 2))\n",
    "    batch_deep_data = np.reshape(np.array(batch_deep_data, dtype=np.int64), newshape=(num_samples, 26, 1))\n",
    "    return batch_label, batch_dense, batch_wide_data, batch_deep_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a4a6a4ca-a04e-43de-be46-295dcfe3681f",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ONNX Runtime Predictions: [0.02525118 0.00920713 0.0080741  ... 0.02893934 0.02577347 0.1296753 ]\n"
     ]
    }
   ],
   "source": [
    "batch_size = 64\n",
    "num_batches = 100\n",
    "data_file = \"./wdl_data/val/sparse_embedding0.data\" # the file contains 40960 samples in total\n",
    "onnx_model_path = \"wdl.onnx\"\n",
    "\n",
    "label, dense, wide_data, deep_data = read_samples_for_wdl(data_file, batch_size*num_batches, key_type=\"I32\", slot_num = 27)\n",
    "import onnxruntime as ort\n",
    "sess = ort.InferenceSession(onnx_model_path)\n",
    "res = sess.run(output_names=[sess.get_outputs()[0].name],\n",
    "                  input_feed={sess.get_inputs()[0].name: dense, sess.get_inputs()[1].name: wide_data, sess.get_inputs()[2].name: deep_data})\n",
    "onnx_preds = res[0].reshape((batch_size*num_batches,))\n",
    "print(\"ONNX Runtime Predictions:\", onnx_preds)"
   ]
  },
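  {
   "cell_type": "markdown",
   "id": "onnx-io-inspection-note",
   "metadata": {},
   "source": [
    "The feed dictionary in the previous cell relies on the order of the entries returned by `sess.get_inputs()`. As a quick sanity check, the following sketch lists the name, shape, and element type of each input and output of the converted model, assuming that `wdl.onnx` was produced by the conversion step. It only uses the `get_inputs` and `get_outputs` accessors of an ONNX Runtime inference session, and the printed names depend on the converted graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "onnx-io-inspection-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import onnxruntime as ort\n",
    "# Load the converted model and print its input and output signature.\n",
    "inspect_sess = ort.InferenceSession(\"wdl.onnx\")\n",
    "for node in inspect_sess.get_inputs():\n",
    "    print(\"input: \", node.name, node.shape, node.type)\n",
    "for node in inspect_sess.get_outputs():\n",
    "    print(\"output:\", node.name, node.shape, node.type)"
   ]
  },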
  {
   "cell_type": "markdown",
   "id": "cee77d5e-59e5-4d48-a81a-987b9927433f",
   "metadata": {
    "tags": []
   },
   "source": [
    "We can then run inference with the HugeCTR APIs and compare the prediction results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "53cab65f-d597-4c8b-86d5-53f2b414a8fd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[17d09h43m49s][HUGECTR][INFO]: default_emb_vec_value is not specified using default: 0.000000\n",
      "[17d09h43m49s][HUGECTR][INFO]: default_emb_vec_value is not specified using default: 0.000000\n",
      "[17d09h43m53s][HUGECTR][INFO]: Global seed is 3782721491\n",
      "[17d09h43m55s][HUGECTR][INFO]: Peer-to-peer access cannot be fully enabled.\n",
      "[17d09h43m55s][HUGECTR][INFO]: Start all2all warmup\n",
      "[17d09h43m55s][HUGECTR][INFO]: End all2all warmup\n",
      "[17d09h43m55s][HUGECTR][INFO]: Use mixed precision: 0\n",
      "[17d09h43m55s][HUGECTR][INFO]: start create embedding for inference\n",
      "[17d09h43m55s][HUGECTR][INFO]: sparse_input name wide_data\n",
      "[17d09h43m55s][HUGECTR][INFO]: sparse_input name deep_data\n",
      "[17d09h43m55s][HUGECTR][INFO]: create embedding for inference success\n",
      "[17d09h43m55s][HUGECTR][INFO]: Inference stage skip BinaryCrossEntropyLoss layer, replaced by Sigmoid layer\n",
      "HugeCTR Predictions:  [0.02525118 0.00920718 0.00807416 ... 0.0289393  0.02577345 0.12967525]\n"
     ]
    }
   ],
   "source": [
    "dense_model = \"wdl_dense_2000.model\"\n",
    "sparse_models = [\"wdl0_sparse_2000.model\", \"wdl1_sparse_2000.model\"]\n",
    "graph_config = \"wdl.json\"\n",
    "data_source = \"./wdl_data/file_list_test.txt\"\n",
    "import hugectr\n",
    "from mpi4py import MPI\n",
    "from hugectr.inference import InferenceParams, CreateInferenceSession\n",
    "inference_params = InferenceParams(model_name = \"wdl\",\n",
    "                                max_batchsize = batch_size,\n",
    "                                hit_rate_threshold = 0.6,\n",
    "                                dense_model_file = dense_model,\n",
    "                                sparse_model_files = sparse_models,\n",
    "                                device_id = 0,\n",
    "                                use_gpu_embedding_cache = True,\n",
    "                                cache_size_percentage = 0.6,\n",
    "                                i64_input_key = False)\n",
    "inference_session = CreateInferenceSession(graph_config, inference_params)\n",
    "hugectr_preds = inference_session.predict(num_batches, data_source, hugectr.DataReaderType_t.Norm, hugectr.Check_t.Sum)\n",
    "print(\"HugeCTR Predictions: \", hugectr_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "c6b8ad91-2967-4762-8294-d203ea8eda85",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Min absolute error:  0.0\n",
      "Mean absolute error:  2.3289697e-08\n",
      "Max absolute error:  1.1920929e-07\n"
     ]
    }
   ],
   "source": [
    "print(\"Min absolute error: \", np.min(np.abs(onnx_preds-hugectr_preds)))\n",
    "print(\"Mean absolute error: \", np.mean(np.abs(onnx_preds-hugectr_preds)))\n",
    "print(\"Max absolute error: \", np.max(np.abs(onnx_preds-hugectr_preds)))"
   ]
  },
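  {
   "cell_type": "markdown",
   "id": "prediction-tolerance-note",
   "metadata": {},
   "source": [
    "The statistics printed above can also be turned into an automated check. The following sketch wraps the comparison in a helper that asserts that the two prediction arrays agree within an absolute tolerance. The helper name `assert_close_predictions` and the tolerance `1e-6` are assumptions chosen for illustration and are not part of HugeCTR or the converter. Pick a tolerance that suits your model and precision settings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "prediction-tolerance-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "def assert_close_predictions(preds_a, preds_b, atol=1e-6):\n",
    "    # Hypothetical helper: flatten both arrays and compare them element-wise.\n",
    "    preds_a = np.asarray(preds_a, dtype=np.float32).reshape(-1)\n",
    "    preds_b = np.asarray(preds_b, dtype=np.float32).reshape(-1)\n",
    "    # rtol=0 so that only the absolute tolerance is applied.\n",
    "    np.testing.assert_allclose(preds_a, preds_b, rtol=0, atol=atol)\n",
    "    return np.max(np.abs(preds_a - preds_b))\n",
    "\n",
    "# Raises an AssertionError if any pair of predictions differs by more than atol.\n",
    "print(\"Max absolute error (within tolerance): \", assert_close_predictions(onnx_preds, hugectr_preds))"
   ]
  },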
  {
   "cell_type": "markdown",
   "id": "3e7d7a99-126c-4d7a-aa54-f675c69f31e7",
   "metadata": {
    "tags": []
   },
   "source": [
    "## API Signature for hugectr2onnx.converter"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5f18194-334f-4479-831b-7c46397ffcf3",
   "metadata": {},
   "source": [
    "```bash\n",
    "NAME\n",
    "    hugectr2onnx.converter\n",
    "\n",
    "FUNCTIONS\n",
    "    convert(onnx_model_path, graph_config, dense_model, convert_embedding=False, sparse_models=[], ntp_file=None, graph_name='hugectr')\n",
    "        Convert a HugeCTR model to an ONNX model\n",
    "        Args:\n",
    "            onnx_model_path: the path to store the ONNX model\n",
    "            graph_config: the graph configuration JSON file of the HugeCTR model\n",
    "            dense_model: the file of the dense weights for the HugeCTR model\n",
    "            convert_embedding: whether to convert the sparse embeddings for the HugeCTR model (optional)\n",
    "            sparse_models: the files of the sparse embeddings for the HugeCTR model (optional)\n",
    "            ntp_file: the file of the non-trainable parameters for the HugeCTR model (optional)\n",
    "            graph_name: the graph name for the ONNX model (optional)\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
